# This file is part of datacube-ows, part of the Open Data Cube project.
# See https://opendatacube.org for more information.
#
# Copyright (c) 2017-2023 OWS Contributors
# SPDX-License-Identifier: Apache-2.0
import xarray

from datacube_ows.startup_utils import initialise_ignorable_warnings
from datacube_ows.styles.base import StandaloneProductProxy, StyleDefBase

# Runs once at import time, before any of the standalone API below is used.
# NOTE(review): presumably installs the warning filters OWS treats as ignorable
# (see datacube_ows.startup_utils) — confirm against that module.
initialise_ignorable_warnings()


def StandaloneStyle(cfg):
    """
    Build an OWS style object that works on its own, without a full OWS configuration environment.

    :param cfg: An OWS Style definition configuration dictionary.

        Refer to the documentation for the valid syntax:

        https://datacube-ows.readthedocs.io/en/latest/cfg_styling.html

    :return: An OWS Style Definition object, already made ready for standalone use.
    """
    # A stand-in product replaces the layer/product the style would normally be attached to.
    product_proxy = StandaloneProductProxy()
    standalone = StyleDefBase(product_proxy, cfg, stand_alone=True)
    # Finalise the style; no datacube handle is supplied in standalone mode.
    standalone.make_ready(None)
    return standalone


def apply_ows_style(style, data, loop_over=None, valid_data_mask=None):
    """
    Apply an OWS style to an ODC XArray to generate a styled image.

    :param style: An OWS Style object, as created by StandaloneStyle()
    :param data: An xarray Dataset, as generated by datacube.load_data()
            Note that the Dataset must contain all of the band names referenced by the standalone style
            configuration.  (The names of the data variables in the dataset must exactly match
            the band names in the configuration.  None of the band aliasing techniques normally
            supported by OWS can work in standalone mode.)
            For bands that are used as bitmaps (i.e. either for masking with pq_mask or colour coding
            in value_map), the data_variable must have a valid flag_definition attribute.
    :param loop_over: (optional) A string which is the name of a dimension in the data to loop over,
            for bulk processing.  E.g. if set to "time", the output will have the same time dimension
            coordinates as the input, with the single-date style being applied to each time slice
            in the input data independently.
    :param valid_data_mask: (optional) An xarray DataArray mask, with dimensions and coordinates matching data.
            When loop_over is set and the mask has that dimension, the mask is sliced alongside the
            data, so each slice is styled with its own mask.  A mask without that dimension is
            applied unchanged to every slice.
    :return: An xarray Dataset, with the same dimensions and coordinates as data, and four data_vars of
            8 bit unsigned integer data named red, green, blue and alpha, representing a 32bit RGBA image.
    """
    if loop_over is None:
        return style.transform_data(data, style.to_mask(data, valid_data_mask))

    image_slices = []
    for coord in data[loop_over].values:
        d_slice = data.sel(**{loop_over: coord})
        # Slice the mask in step with the data.  Passing the full multi-slice mask
        # would not line up with the single-slice data.
        if valid_data_mask is not None and loop_over in valid_data_mask.dims:
            mask_slice = valid_data_mask.sel(**{loop_over: coord})
        else:
            mask_slice = valid_data_mask
        image_slices.append(
            style.transform_data(d_slice, style.to_mask(d_slice, mask_slice))
        )
    # Re-assemble the per-slice images along the looped-over dimension.
    return xarray.concat(image_slices, data[loop_over])


def apply_ows_style_cfg(cfg, data, loop_over=None, valid_data_mask=None):
    """
    Apply an OWS style configuration to an ODC XArray to generate a styled image.

    A new standalone style object is built from cfg on every call.  To style many
    datasets with the same configuration, build it once with StandaloneStyle() and
    call apply_ows_style() instead.

    :param cfg: A valid OWS Style definition configuration dictionary.

        Refer to the documentation for the valid syntax:

        https://datacube-ows.readthedocs.io/en/latest/cfg_styling.html
    :param data: An xarray Dataset, as generated by datacube.load_data()
            Note that the Dataset must contain all of the band names referenced by the standalone style
            configuration.  (The names of the data variables in the dataset must exactly match
            the band names in the configuration.  None of the band aliasing techniques normally
            supported by OWS can work in standalone mode.)
            For bands that are used as bitmaps (i.e. either for masking with pq_mask or colour coding
            in value_map), the data_variable must have a valid flag_definition attribute.
    :param loop_over: (optional) A string which is the name of a dimension in the data to loop over,
            for bulk processing.  E.g. if set to "time", the output will have the same time dimension
            coordinates as the input, with the single-date style being applied to each time slice
            in the input data independently.
    :param valid_data_mask: (optional) An xarray DataArray mask, with dimensions and coordinates matching data.
    :return: An xarray Dataset, with the same dimensions and coordinates as data, and four data_vars of
            8 bit unsigned integer data named red, green, blue and alpha, representing a 32bit RGBA image.
    """
    return apply_ows_style(
        StandaloneStyle(cfg),
        data,
        loop_over=loop_over,
        valid_data_mask=valid_data_mask
    )


def generate_ows_legend_style(style, ndates=0):
    """
    Render the legend image of an OWS style object.

    :param style: An OWS Style object, as created by StandaloneStyle()
    :param ndates: (optional) How many dates to render the legend for (relevant to styles with multi-date handlers)
    :return: A PIL Image object.
    """
    legend_image = style.render_legend(ndates)
    return legend_image


def generate_ows_legend_style_cfg(cfg, ndates=0):
    """
    Render the legend image described by an OWS style configuration.

    :param cfg: A valid OWS Style definition configuration dictionary.

        Refer to the documentation for the valid syntax:

        https://datacube-ows.readthedocs.io/en/latest/cfg_styling.html
    :param ndates: (optional) How many dates to render the legend for (relevant to styles with multi-date handlers)
    :return: A PIL Image object.
    """
    standalone = StandaloneStyle(cfg)
    return generate_ows_legend_style(standalone, ndates)


def plot_image(xr_image, x="x", y="y", size=10, aspect=None):
    """
    Plot an Xarray image with matplotlib. (e.g. for display in JupyterHub)

    :param xr_image: An xarray image, as returned by the "apply_ows_style" functions
    :param x: The name of the dimension to be plotted horizontally (optional, defaults to "x")
    :param y: The name of the dimension to be plotted vertically (optional, defaults to "y")
    :param size: The height of the plotted image, in inches (optional, defaults to 10)
    :param aspect: The aspect ratio of the plotted image (width/height of plotted image).
                (defaults to None, which means use the aspect ratio of the data.)
    """
    if aspect is None:
        # Only fall back to the data's own proportions (pixels wide / pixels high)
        # when the caller did not ask for a specific aspect ratio.
        aspect = len(xr_image[x]) / len(xr_image[y])
    # Stack the data_vars (red, green, blue, alpha) into one DataArray with a
    # leading "color" dimension, then rotate that dimension to the end.
    rgb = xr_image.to_array(dim="color")
    rgb = rgb.transpose(*(rgb.dims[1:] + rgb.dims[:1]))
    # Scale 0-255 channel values into the 0-1 range used for float image data.
    rgb = rgb / 255
    rgb.plot.imshow(x=x, y=y, size=size, aspect=aspect)


def plot_image_with_style(style, data, x="x", y="y", size=10, aspect=None, valid_data_mask=None):
    """
    Style some data with an OWS style object, then draw the result with matplotlib
    (e.g. for display in JupyterHub).

    :param style: An OWS Style object, as created by StandaloneStyle()
    :param data: An xarray Dataset, as generated by datacube.load_data()
            The Dataset must contain every band name referenced by the standalone style
            configuration.  (Data variable names must exactly match the band names in the
            configuration; the band aliasing normally supported by OWS is unavailable in
            standalone mode.)
            Bands used as bitmaps (for masking with pq_mask, or colour coding in value_map)
            need a valid flag_definition attribute on their data_variable.
    :param x: The name of the dimension to be plotted horizontally (optional, defaults to "x")
    :param y: The name of the dimension to be plotted vertically (optional, defaults to "y")
    :param size: The height of the plotted image, in inches (optional, defaults to 10)
    :param aspect: The aspect ratio of the plotted image (width/height of plotted image).
                (defaults to None, which means use the aspect ratio of the data.)
    :param valid_data_mask: (optional) An xarray DataArray mask, with dimensions and coordinates matching data.
    """
    styled = apply_ows_style(style, data, valid_data_mask=valid_data_mask)
    plot_image(styled, x=x, y=y, size=size, aspect=aspect)


def plot_image_with_style_cfg(cfg, data, x="x", y="y", size=10, aspect=None, valid_data_mask=None):
    """
    Style some data with an OWS style configuration, then draw the result with matplotlib
    (e.g. for display in JupyterHub).

    :param cfg: A valid OWS Style definition configuration dictionary.

        Refer to the documentation for the valid syntax:

        https://datacube-ows.readthedocs.io/en/latest/cfg_styling.html
    :param data: An xarray Dataset, as generated by datacube.load_data()
            The Dataset must contain every band name referenced by the standalone style
            configuration.  (Data variable names must exactly match the band names in the
            configuration; the band aliasing normally supported by OWS is unavailable in
            standalone mode.)
            Bands used as bitmaps (for masking with pq_mask, or colour coding in value_map)
            need a valid flag_definition attribute on their data_variable.
    :param x: The name of the dimension to be plotted horizontally (optional, defaults to "x")
    :param y: The name of the dimension to be plotted vertically (optional, defaults to "y")
    :param size: The height of the plotted image, in inches (optional, defaults to 10)
    :param aspect: The aspect ratio of the plotted image (width/height of plotted image).
                (defaults to None, which means use the aspect ratio of the data.)
    :param valid_data_mask: (optional) An xarray DataArray mask, with dimensions and coordinates matching data.
    """
    styled = apply_ows_style_cfg(cfg, data, valid_data_mask=valid_data_mask)
    plot_image(styled, x=x, y=y, size=size, aspect=aspect)
